{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "llx_cGpoAyAq"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "5MAYU_6KA0Kt"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "35lZ8kr3UcsB"
      },
      "source": [
        "# Data augmentation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/data_augmentation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/data_augmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/data_augmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/data_augmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BZP72A6eFw74"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial demonstrates manual image manipulations and augmentation using `tf.image`.\n",
        "\n",
        "Data augmentation is a common technique to improve results and avoid overfitting. See [Overfitting and Underfitting](../keras/overfit_and_underfit.ipynb) for other techniques."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8sZIVqk7HvnC"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "3TS4vBJBd8jY"
      },
      "outputs": [],
      "source": [
        "!pip install git+https://github.com/tensorflow/docs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rdP8EQbPsyRA"
      },
      "outputs": [],
      "source": [
        "import urllib\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow.keras.datasets import mnist\n",
        "from tensorflow.keras import layers\n",
        "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
        "\n",
        "import tensorflow_docs as tfdocs\n",
        "import tensorflow_docs.plots\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "import PIL.Image\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib as mpl\n",
        "mpl.rcParams['figure.figsize'] = (12, 5)\n",
        "\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "__EuXwM-8uth"
      },
      "source": [
        "First, try the augmentation operations on a single image. Later, augment a whole dataset and train a model on it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "frBSdODBLOOI"
      },
      "source": [
        "Download [this image](https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg), by Von.grzanka, for augmentation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "s5ThIwG8KqzI"
      },
      "outputs": [],
      "source": [
        "image_path = tf.keras.utils.get_file(\"cat.jpg\", \"https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg\")\n",
        "PIL.Image.open(image_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-Ec3bGonGDCF"
      },
      "source": [
        "Read and decode the image to tensor format."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cdCoB8b-uZjf"
      },
      "outputs": [],
      "source": [
        "image_string = tf.io.read_file(image_path)\n",
        "image = tf.image.decode_jpeg(image_string, channels=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "isGwyT0386yi"
      },
      "source": [
        "Define a function to visualize and compare the original and augmented images side by side."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "FKnRfw2dvyql"
      },
      "outputs": [],
      "source": [
        "def visualize(original, augmented):\n",
        "  fig = plt.figure()\n",
        "  plt.subplot(1,2,1)\n",
        "  plt.title('Original image')\n",
        "  plt.imshow(original)\n",
        "\n",
        "  plt.subplot(1,2,2)\n",
        "  plt.title('Augmented image')\n",
        "  plt.imshow(augmented)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jYLzpEOhGqWY"
      },
      "source": [
        "## Augment a single image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8IiXghY99Bo6"
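      },
      "source": [
        "### Flip the image\n",
        "Flip the image either vertically or horizontally. This example flips it horizontally with `tf.image.flip_left_right`; `tf.image.flip_up_down` flips it vertically."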
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "X14VjLlFxnvZ"
      },
      "outputs": [],
      "source": [
        "flipped = tf.image.flip_left_right(image)\n",
        "visualize(image, flipped)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ObsvSmu99MfC"
      },
      "source": [
        "### Grayscale the image\n",
        "Convert the image to grayscale."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mnqQA2ubyo6O"
      },
      "outputs": [],
      "source": [
        "grayscaled = tf.image.rgb_to_grayscale(image)\n",
        "visualize(image, tf.squeeze(grayscaled))\n",
        "plt.colorbar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "juI4A4HF9gYc"
      },
      "source": [
        "### Saturate the image\n",
        "Saturate an image by providing a saturation factor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tiTUhw-gzCJW"
      },
      "outputs": [],
      "source": [
        "saturated = tf.image.adjust_saturation(image, 3)\n",
        "visualize(image, saturated)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "E82CqomP9qcR"
      },
      "source": [
        "### Change image brightness\n",
        "Change the brightness of the image by providing a brightness delta, which is added to the pixel values."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "05dA6uEtzfyd"
      },
      "outputs": [],
      "source": [
        "bright = tf.image.adjust_brightness(image, 0.4)\n",
        "visualize(image, bright)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5_0kMbmS91x6"
      },
      "source": [
        "### Rotate the image\n",
        "Rotate the image by 90 degrees counter-clockwise."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "edNoQzhszxo8"
      },
      "outputs": [],
      "source": [
        "rotated = tf.image.rot90(image)\n",
        "visualize(image, rotated)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bomBnFWp9895"
      },
      "source": [
        "### Center crop the image\n",
        "Crop the image from the center, keeping the fraction of the image you specify."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "fvgz_6t21dq2"
      },
      "outputs": [],
      "source": [
        "cropped = tf.image.central_crop(image, central_fraction=0.5)\n",
        "visualize(image, cropped)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8W5E_c7o-H96"
      },
      "source": [
        "See the `tf.image` reference for details about available augmentation options."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "92lBGZSQ-1Tx"
      },
      "source": [
        "## Augment a dataset and train a model with it"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lrDez4xIX9Ss"
      },
      "source": [
        "Train a model on an augmented dataset.\n",
        "\n",
        "Note: The problem solved here is somewhat artificial. It trains a densely connected network to be shift invariant by jittering the input images. It's much more efficient to use convolutional layers instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mazlEonS_gTR"
      },
      "outputs": [],
      "source": [
        "dataset, info = tfds.load('mnist', as_supervised=True, with_info=True)\n",
        "train_dataset, test_dataset = dataset['train'], dataset['test']\n",
        "\n",
        "num_train_examples = info.splits['train'].num_examples"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "011caOa0YCz5"
      },
      "source": [
        "Write a function to augment the images and map it over the dataset. This returns a dataset that augments the data on the fly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "3oaSV5QcDS8p"
      },
      "outputs": [],
      "source": [
        "def convert(image, label):\n",
        "  image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]\n",
        "  return image, label\n",
        "\n",
        "def augment(image, label):\n",
        "  image, label = convert(image, label) # Cast and normalize the image to [0,1]\n",
        "  image = tf.image.resize_with_crop_or_pad(image, 34, 34) # Pad to 34x34 (3 pixels on each side)\n",
        "  image = tf.image.random_crop(image, size=[28, 28, 1]) # Random crop back to 28x28\n",
        "  image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness\n",
        "\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xNROtM5uqSjg"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE = 64\n",
        "# Only use a subset of the data so it's easier to overfit, for this tutorial\n",
        "NUM_EXAMPLES = 2048"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dEq0VnUIy-8l"
      },
      "source": [
        "Create the augmented dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "jWJpV6JOqN_7"
      },
      "outputs": [],
      "source": [
        "augmented_train_batches = (\n",
        "    train_dataset\n",
        "    # Only train on a subset, so you can quickly see the effect.\n",
        "    .take(NUM_EXAMPLES)\n",
        "    .cache()\n",
        "    .shuffle(num_train_examples//4)\n",
        "    # The augmentation is added here.\n",
        "    .map(augment, num_parallel_calls=AUTOTUNE)\n",
        "    .batch(BATCH_SIZE)\n",
        "    .prefetch(AUTOTUNE)\n",
        ") "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PGm5X_77zE3q"
      },
      "source": [
        "Create a non-augmented dataset for comparison."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "td4yU920qOgU"
      },
      "outputs": [],
      "source": [
        "non_augmented_train_batches = (\n",
        "    train_dataset\n",
        "    # Only train on a subset, so you can quickly see the effect.\n",
        "    .take(NUM_EXAMPLES)\n",
        "    .cache()\n",
        "    .shuffle(num_train_examples//4)\n",
        "    # No augmentation.\n",
        "    .map(convert, num_parallel_calls=AUTOTUNE)\n",
        "    .batch(BATCH_SIZE)\n",
        "    .prefetch(AUTOTUNE)\n",
        ") "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zkLkgEudzKfc"
      },
      "source": [
        "Set up the validation dataset. It is the same whether or not you use augmentation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "6eqKD4zOqpvE"
      },
      "outputs": [],
      "source": [
        "validation_batches = (\n",
        "    test_dataset\n",
        "    .map(convert, num_parallel_calls=AUTOTUNE)\n",
        "    .batch(2*BATCH_SIZE)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Yi9TIwR-ZIOi"
      },
      "source": [
        "Create and compile the model. The model is a fully connected neural network with two hidden layers and no convolutions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "hHhkA4Q0CsHx"
      },
      "outputs": [],
      "source": [
        "def make_model():\n",
        "  model = tf.keras.Sequential([\n",
        "      layers.Flatten(input_shape=(28, 28, 1)),\n",
        "      layers.Dense(4096, activation='relu'),\n",
        "      layers.Dense(4096, activation='relu'),\n",
        "      layers.Dense(10)\n",
        "  ])\n",
        "  model.compile(optimizer = 'adam',\n",
        "                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "                metrics=['accuracy'])\n",
        "  return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "P0rciou3ZWwy"
      },
      "source": [
        "Train the model, **without** augmentation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "z8X8CpqvNhG9"
      },
      "outputs": [],
      "source": [
        "model_without_aug = make_model()\n",
        "\n",
        "no_aug_history = model_without_aug.fit(non_augmented_train_batches, epochs=50, validation_data=validation_batches)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yH3LUd1prFWe"
      },
      "source": [
        "Train it again with augmentation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "z4RqS9FWrEiS"
      },
      "outputs": [],
      "source": [
        "model_with_aug = make_model()\n",
        "\n",
        "aug_history = model_with_aug.fit(augmented_train_batches, epochs=50, validation_data=validation_batches)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UEqeeNsHZaC5"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "In this example, the augmented model converges to an accuracy of about 95% on the validation set. This is slightly higher (about 1 percentage point) than the model trained without data augmentation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ZnkliKKm0VVw"
      },
      "outputs": [],
      "source": [
        "plotter = tfdocs.plots.HistoryPlotter()\n",
        "plotter.plot({\"Augmented\": aug_history, \"Non-Augmented\": no_aug_history}, metric = \"accuracy\")\n",
        "plt.title(\"Accuracy\")\n",
        "plt.ylim([0.75,1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cAaUjNhI3j7k"
      },
      "source": [
        "In terms of loss, the non-augmented model is clearly in the overfitting regime. The augmented model, while a few epochs slower, is still training correctly and is not overfitting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1-dTZpc4zq0-"
      },
      "outputs": [],
      "source": [
        "plotter = tfdocs.plots.HistoryPlotter()\n",
        "plotter.plot({\"Augmented\": aug_history, \"Non-Augmented\": no_aug_history}, metric = \"loss\")\n",
        "plt.title(\"Loss\")\n",
        "plt.ylim([0,1])"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "data_augmentation.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
